#coding=utf-8
from cs231n.classifiers import linear_svm
import numpy as np

class LinearClassifier(object):
    def __init__(self):
        # Weight matrix of shape [D, C]; lazily initialized on first train().
        self.W = None
    def train(self, X, y, learning_rate=1e-3, reg=1e-5, num_iters=100, batch_size=200, verbose=False):
        '''
        Train this linear classifier with stochastic gradient descent.

        Inputs:
        - X: training data of shape [N, D]
        - y: training labels of shape [N]; y[i] = c with 0 <= c < C
        - learning_rate: step size for the SGD update
        - reg: regularization strength
        - num_iters: number of SGD iterations to run
        - batch_size: number of samples per mini-batch
        - verbose: if True, print the loss every 100 iterations

        Outputs:
        A list with the loss value recorded at every iteration.
        '''
        num_train, dim = X.shape
        num_classes = np.max(y) + 1  # labels are 0 .. C-1
        if self.W is None:
            # Small random initialization on first use.
            self.W = 0.001 * np.random.rand(dim, num_classes)
        loss_history = []
        for it in range(num_iters):
            # Sample with replacement: cheaper than without, and standard
            # practice for SGD mini-batches (duplicates are acceptable).
            batch_idx = np.random.choice(num_train, batch_size, replace=True)
            X_batch = X[batch_idx]
            y_batch = y[batch_idx]
            # BUG FIX: delegate to self.loss() instead of hard-coding the SVM
            # loss, so subclasses (LinearSVM, Softmax) train with their own
            # loss function as intended.
            loss, grad = self.loss(X_batch, y_batch, reg)
            loss_history.append(loss)
            # Vanilla gradient-descent parameter update.
            self.W -= learning_rate * grad
            if verbose and it % 100 == 0:
                print('iteration %d / %d: loss %f' % (it, num_iters, loss))
        return loss_history
    def predict(self, X):
        '''
        Use the trained weights to predict labels for data points.

        Inputs:
        - X: data of shape [N, D]
        Outputs:
        - y_pred: array of shape [N] with the predicted class index for each point
        '''
        # Scores are [N, C]; pick the highest-scoring class per row.
        scores = X.dot(self.W)
        y_pred = np.argmax(scores, axis=1)
        return y_pred
    def loss(self, X_batch, y_batch, reg):
        '''
        Compute the loss and its gradient with respect to self.W.

        Subclasses must override this and return a tuple (loss, grad),
        where grad has the same shape as self.W.
        '''
        raise NotImplementedError

class LinearSVM(LinearClassifier):
    '''Linear classifier trained with the multiclass SVM (hinge) loss.'''
    def loss(self, X_batch, y_batch, reg):
        '''Return (loss, gradient) of the vectorized SVM loss on this mini-batch.'''
        svm_loss, svm_grad = linear_svm.svm_loss_vectorized(self.W, X_batch, y_batch, reg)
        return svm_loss, svm_grad

class Softmax(LinearClassifier):
    '''Linear classifier trained with the softmax (cross-entropy) loss.'''
    def loss(self, X_batch, y_batch, reg):
        '''Return the vectorized softmax loss and gradient on this mini-batch.'''
        # NOTE(review): softmax_loss_vectorized is looked up on the linear_svm
        # module — confirm that module actually defines it.
        result = linear_svm.softmax_loss_vectorized(self.W, X_batch, y_batch, reg)
        return result





